Predicting Pathogen from RNAseq data


In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from math import pow
from sklearn.feature_selection import RFECV, RFE
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn.svm import LinearSVC, LinearSVR

random_state = 7

In [2]:
patient_groups=["control", "viral", "bacterial", "fungal"]
group_id = lambda name: patient_groups.index(name)

X = pd.DataFrame.from_csv("combineSV_WTcpmtable_v2.txt", sep="\s+").T
# + [group_id("fungal")] * 10 \
y = [group_id("bacterial")] * 29 \
    + [group_id("viral")] * 42 \
    + [group_id("control")] * 61
    
# Drop the fungal patients
fv = range(29+42, 29+42+10)
X = X.drop(X.index[fv])

print "Complete data set has %d samples and %d features." % (X.shape[0], X.shape[1])
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, 
                                                    random_state=random_state, 
                                                    stratify=y)
print "Training set has %d samples. Testing set has %d samples." % (len(X_train), len(X_test))


Complete data set has 132 samples and 25342 features.
Training set has 105 samples. Testing set has 27 samples.

In [3]:
def print_gridsearch_results(clf):
    print("Best: %f using %s" % (clf.best_score_, clf.best_params_))
    means = clf.cv_results_['mean_test_score']
    stds = clf.cv_results_['std_test_score']
    params = clf.cv_results_['params']
    for mean, stdev, param in zip(means, stds, params):
        print("%f (%f) with: %r" % (mean, stdev, param))

In [4]:
from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression

#parameters={'estimator__C': [pow(2, i) for i in xrange(-25, 4, 1)]}
est = LogisticRegression(class_weight="balanced", C=pow(2, -23))
#est = LinearSVR(C=pow(2, -23))
#clf = RFE(estimator=lr, step=0.01, n_features_to_select=1, verbose=1)
clf = RFECV(cv=4, estimator=est, n_jobs=6, scoring='neg_mean_squared_error', step=.001, verbose=0)
#clf = GridSearchCV(rfe, parameters, scoring='accuracy', n_jobs=8, cv=3, verbose=1)
clf.fit(X_train, y_train)
#print_gridsearch_results(clf)


Out[4]:
RFECV(cv=4,
   estimator=LogisticRegression(C=1.19209289551e-07, class_weight='balanced', dual=False,
          fit_intercept=True, intercept_scaling=1, max_iter=100,
          multi_class='ovr', n_jobs=1, penalty='l2', random_state=None,
          solver='liblinear', tol=0.0001, verbose=0, warm_start=False),
   n_jobs=6, scoring='neg_mean_squared_error', step=0.001, verbose=0)

In [5]:
from IPython.core.display import display, HTML

rfe_features = []
for (rank, name) in zip(clf.ranking_, X.columns):
    if rank == 1:
        rfe_features.append(name)

s="""
<h2>List of %d genes found by RFE</h2>
<p>Note: the NCBI link will open the target in a new window or tab.</p>
<table>
""" % (clf.n_features_)

ncbi_url = lambda gene: "https://www.ncbi.nlm.nih.gov/gene/?term=%s[Gene Name] AND Human[Organism]" % (gene)

s += "<tr>"
for (i, gene) in enumerate(rfe_features):
    if not i % 10:
        s += "</tr><tr>"
    s += """
    <td><a target=\"_blank\" href=\"%s\">%s</a></td>
    """ %(ncbi_url(gene), gene)
s += "</tr></table>"

display(HTML(s))





In [6]:
best_estimator = clf

print("Optimal number of features : %d" % best_estimator.n_features_)
print("Recursive Feature Elimination (RFE) eliminated %d features" % (X.shape[1] - best_estimator.n_features_))

# Plot number of features VS. cross-validation scores
plt.figure()
plt.xlabel("Number of feature subsets selected")
plt.ylabel("Cross validation score")
plt.plot(range(1, len(best_estimator.grid_scores_) + 1), best_estimator.grid_scores_)
plt.show()


Optimal number of features : 167
Recursive Feature Elimination (RFE) eliminated 25175 features

In [7]:
%matplotlib inline
from learning_curves import plot_learning_curve
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedShuffleSplit

def create_learning_curve(title, model):
    cv = StratifiedShuffleSplit(n_splits=2, test_size=0.3, random_state=random_state)                                     
    plot_learning_curve(model, title, X, y, (0.2, 1.01), cv=cv, n_jobs=1)
    
create_learning_curve("Learning Curves Full Feature Set", est)
create_learning_curve("Learning Curves RFE Feature Set", clf.estimator_)

plt.show()


None

Make predictions based on the model


In [11]:
from classification_metrics import classification_metrics

est.fit(X_train, y_train)
est_predicted = est.predict(X_test)
print "Full Metrics"
classification_metrics(y_test, est_predicted, patient_groups)
rfe_predicted = clf.predict(X_test)
print "-" * 80
print "\nRFE Metrics"
classification_metrics(y_test, rfe_predicted, patient_groups)


Full Metrics
Accuracy was 66.67%

             precision    recall  f1-score   support

    control       0.70      0.58      0.64        12
      viral       0.71      0.56      0.63         9
  bacterial       0.60      1.00      0.75         6

avg / total       0.68      0.67      0.66        27

Confusion Matrix: cols = predictions, rows = actual

                       control          viral      bacterial         fungal
        control              7              2              3              0
          viral              3              5              1              0
      bacterial              0              0              6              0
         fungal              0              0              0              0
--------------------------------------------------------------------------------

RFE Metrics
Accuracy was 74.07%

             precision    recall  f1-score   support

    control       0.88      0.58      0.70        12
      viral       0.78      0.78      0.78         9
  bacterial       0.60      1.00      0.75         6

avg / total       0.78      0.74      0.74        27

Confusion Matrix: cols = predictions, rows = actual

                       control          viral      bacterial         fungal
        control              7              2              3              0
          viral              1              7              1              0
      bacterial              0              0              6              0
         fungal              0              0              0              0

Review model predictions


In [12]:
probs = pd.DataFrame(clf.predict_proba(X_test))
probstrs = lambda vals: ["%.2f" % (p*100) for p in vals]

d = {"Predicted": [patient_groups[i] for i in rfe_predicted],
     "Actual": [patient_groups[i] for i in y_test],
     "Prob. control": probstrs(probs[0]),
     "Prob. viral": probstrs(probs[1]),
     "Prob. bacteria": probstrs(probs[2])}

patient_df = pd.DataFrame(d, index=X_test.index)
patient_df.sort_values(by="Actual")


Out[12]:
Actual Predicted Prob. bacteria Prob. control Prob. viral
MN_223 bacterial bacterial 65.10 0.00 34.90
MN_140 bacterial bacterial 96.85 0.11 3.04
MN_304 bacterial bacterial 75.90 10.18 13.91
MN_224 bacterial bacterial 46.74 21.19 32.07
MN_324 bacterial bacterial 90.39 9.61 0.00
MNC.473 bacterial bacterial 65.08 25.68 9.23
MNC.571 control control 2.84 76.51 20.65
MN_368 control bacterial 95.01 4.73 0.26
MN_366 control control 2.68 81.32 16.00
MNC.213 control control 25.97 42.77 31.26
MNC.294 control control 0.01 99.98 0.01
MNC.533 control control 1.60 90.78 7.61
MNC.631 control bacterial 99.55 0.40 0.05
MNC.675 control control 0.39 85.34 14.27
MNC.116 control viral 25.64 3.90 70.46
MNC.054 control viral 20.20 8.06 71.74
MNC.151 control control 0.43 91.56 8.00
MNC.291 control bacterial 70.70 5.79 23.51
MNC.234 viral viral 34.94 4.32 60.74
MNC.033 viral viral 18.34 33.67 48.00
MN_171 viral viral 4.29 42.05 53.66
MN_282 viral viral 31.02 15.04 53.94
MNC.215 viral control 0.59 71.32 28.09
MNC.098 viral viral 5.28 24.66 70.06
MNC.176 viral viral 0.98 11.19 87.83
MNC.331 viral viral 4.97 46.34 48.69
MNC.015 viral bacterial 57.11 26.78 16.12

Review patients the model misclassified


In [13]:
patient_df[patient_df["Predicted"] != patient_df["Actual"]]


Out[13]:
Actual Predicted Prob. bacteria Prob. control Prob. viral
MNC.054 control viral 20.20 8.06 71.74
MNC.116 control viral 25.64 3.90 70.46
MNC.631 control bacterial 99.55 0.40 0.05
MNC.291 control bacterial 70.70 5.79 23.51
MNC.015 viral bacterial 57.11 26.78 16.12
MNC.215 viral control 0.59 71.32 28.09
MN_368 control bacterial 95.01 4.73 0.26